/*
Copyright 2021 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package asthelpergen

import (
	"go/types"
	"sort"
	"strings"

	"github.com/dave/jennifer/jen"
)

// Comparator is the name of the generated struct type that carries the
// deep-equality methods; every generated method uses *Comparator as its
// receiver, and genFile declares the type itself.
const Comparator = "Comparator"

// EqualsOptions configures the equals generator behavior.
type EqualsOptions struct {
	// AllowCustom specifies types that can have custom equality comparators.
	// For these types, the generated Comparator struct will include function fields
	// that allow custom comparison logic to be injected at runtime.
	//
	// Entries are matched against the type string (types.TypeString with
	// noQualifier) of pointer-to-struct types, so entries presumably look like
	// "*ColName" — confirm against callers. A Comparator field is emitted only
	// for entries whose type is actually generated.
	AllowCustom []string
}

// equalsGen generates the Comparator type and its deep-equality methods.
type equalsGen struct {
	// file accumulates all generated declarations.
	file        *jen.File
	// comparators maps each type name allowed to have a custom comparator to
	// its type. A nil value means the name was allowed through
	// EqualsOptions.AllowCustom but no matching type has been generated yet.
	comparators map[string]types.Type
}

// Compile-time check that *equalsGen satisfies the generator interface.
var _ generator = (*equalsGen)(nil)

// newEqualsGen creates the generator that emits the Comparator type and its
// deep-equality methods into a new file for package pkgname.
//
// options may be nil, which is treated as "no custom comparators". Every type
// name listed in options.AllowCustom is registered with a nil placeholder;
// ptrToStructMethod replaces the placeholder with the concrete type when it
// generates that type's method, and only then does genFile emit a field for it.
func newEqualsGen(pkgname string, options *EqualsOptions) *equalsGen {
	file := jen.NewFile(pkgname)
	file.HeaderComment(licenseFileHeader)
	file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.")

	// Guard against a nil options pointer instead of panicking on dereference.
	var allowCustom []string
	if options != nil {
		allowCustom = options.AllowCustom
	}

	customComparators := make(map[string]types.Type, len(allowCustom))
	for _, name := range allowCustom {
		customComparators[name] = nil
	}

	return &equalsGen{
		file:        file,
		comparators: customComparators,
	}
}

// addFunc appends a generated declaration to the output file, preceded by a
// doc comment that names it.
func (e *equalsGen) addFunc(name string, code *jen.Statement) {
	header := jen.Comment(name + " does deep equals between the two objects.")
	for _, item := range []jen.Code{header, code} {
		e.file.Add(item)
	}
}

// customComparatorField returns the name of the Comparator field that holds
// the optional caller-supplied comparison function for t: the printable type
// name followed by a trailing underscore.
func (e *equalsGen) customComparatorField(t types.Type) string {
	const fieldSuffix = "_"
	base := printableTypeName(t)
	return base + fieldSuffix
}

// genFile emits the Comparator struct and returns the output file name along
// with the accumulated file contents.
//
// Comparator gets one function-typed field per custom comparator that was
// actually used during generation. Entries still mapped to nil were allowed
// via EqualsOptions.AllowCustom but never generated, so they are skipped.
// Fields are emitted in sorted type-name order so the generated file is
// reproducible; ranging over the map directly would yield a different field
// order from run to run.
func (e *equalsGen) genFile(generatorSPI) (string, *jen.File) {
	names := make([]string, 0, len(e.comparators))
	for name, t := range e.comparators {
		if t != nil {
			names = append(names, name)
		}
	}
	sort.Strings(names)

	e.file.Type().Id(Comparator).StructFunc(func(g *jen.Group) {
		for _, tname := range names {
			method := e.customComparatorField(e.comparators[tname])
			g.Add(jen.Id(method).Func().Call(jen.List(jen.Id("a"), jen.Id("b")).Id(tname)).Bool())
		}
	})

	return "ast_equals.go", e.file
}

// interfaceMethod generates the Comparator method for the interface type t.
//
// The generated method treats two nil values as equal and a single nil as
// unequal. It then type-switches on inA over every concrete implementation of
// iface, asserts inB to the same type, and delegates to that type's
// comparison. Implementations that are themselves interfaces are skipped.
// Any error reported by spi.findImplementations is returned to the caller.
func (e *equalsGen) interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error {
	/*
		func (cmp *Comparator) AST(inA, inB AST) bool {
			if inA == nil && inB == nil {
				return true
			}
			if inA == nil || inB == nil {
				return false
			}
			switch a := inA.(type) {
			case *SubImpl:
				b, ok := inB.(*SubImpl)
				if !ok {
					return false
				}
				return cmp.RefOfSubImpl(a, b)
			default:
				// this should never happen
				return false
			}
		}
	*/
	stmts := []jen.Code{
		jen.If(jen.Id("inA == nil").Op("&&").Id("inB == nil")).Block(jen.Return(jen.True())),
		jen.If(jen.Id("inA == nil").Op("||").Id("inB == nil")).Block(jen.Return(jen.False())),
	}

	var cases []jen.Code
	err := spi.findImplementations(iface, func(impl types.Type) error {
		if _, ok := impl.Underlying().(*types.Interface); ok {
			// A dynamic type is never an interface type, so no case is needed.
			return nil
		}
		typeString := types.TypeString(impl, noQualifier)
		caseBlock := jen.Case(jen.Id(typeString)).Block(
			jen.Id("b, ok := inB.").Call(jen.Id(typeString)),
			jen.If(jen.Id("!ok")).Block(jen.Return(jen.False())),
			jen.Return(compareValueType(impl, jen.Id("a"), jen.Id("b"), true, spi)),
		)
		cases = append(cases, caseBlock)
		return nil
	})
	if err != nil {
		// Previously discarded; surface it so a broken lookup is not silently
		// turned into a method with missing cases.
		return err
	}

	cases = append(cases,
		jen.Default().Block(
			jen.Comment("this should never happen"),
			jen.Return(jen.False()),
		))

	stmts = append(stmts, jen.Switch(jen.Id("a := inA.(type)").Block(
		cases...,
	)))

	funcDecl, funcName := e.declareFunc(t, "inA", "inB")
	e.addFunc(funcName, funcDecl.Block(stmts...))

	return nil
}

// compareValueType builds an expression comparing a and b, which both have
// type t.
//
// When t's underlying type is basic, the values are compared inline with ==
// (or != when eq is false). Otherwise t is registered with spi so that its
// comparison method gets generated, and the expression calls
// cmp.<TypeName>(a, b), prefixed with ! when eq is false.
func compareValueType(t types.Type, a, b *jen.Statement, eq bool, spi generatorSPI) *jen.Statement {
	if _, isBasic := t.Underlying().(*types.Basic); isBasic {
		operator := "!="
		if eq {
			operator = "=="
		}
		return a.Op(operator).Add(b)
	}

	spi.addType(t)
	call := jen.Id("cmp").Dot(printableTypeName(t)).Call(a, b)
	if eq {
		return call
	}
	return jen.Op("!").Add(call)
}

// structMethod generates the Comparator method for the struct value type t.
// The emitted method returns the conjunction of per-field comparisons built by
// compareAllStructFields. For a struct with one basic field and two AST
// fields, it looks something like:
//
//	func (cmp *Comparator) RefContainer(a, b RefContainer) bool {
//		return a.NotASTType == b.NotASTType &&
//			cmp.RefOfLeaf(a.ASTImplementationType, b.ASTImplementationType) &&
//			cmp.AST(a.ASTType, b.ASTType)
//	}
func (e *equalsGen) structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error {
	decl, name := e.declareFunc(t, "a", "b")
	body := jen.Return(compareAllStructFields(strct, spi))
	e.addFunc(name, decl.Block(body))
	return nil
}

// compareAllStructFields builds one boolean expression that is true when every
// compared field of strct is equal between the values `a` and `b`.
//
// Fields whose underlying type is `any`, and fields whose name starts with
// "_", are skipped. Fields whose underlying type is basic are placed first,
// because compareValueType compares them inline with ==. That lets the cheap
// checks short-circuit before any recursive comparison. The remaining fields
// follow in declaration order. If no field is compared, the result is the
// literal `true`.
func compareAllStructFields(strct *types.Struct, spi generatorSPI) jen.Code {
	var basicsPred []*jen.Statement
	var others []*jen.Statement
	for i := 0; i < strct.NumFields(); i++ {
		field := strct.Field(i)
		if field.Type().Underlying().String() == anyTypeName || strings.HasPrefix(field.Name(), "_") {
			// we can safely ignore this, we do not want ast to contain `any` types.
			continue
		}
		fieldA := jen.Id("a").Dot(field.Name())
		fieldB := jen.Id("b").Dot(field.Name())
		pred := compareValueType(field.Type(), fieldA, fieldB, true, spi)
		// Classify by the underlying type, exactly as compareValueType does, so
		// named basic types (e.g. `type Kind int8`) also land in the cheap group.
		if _, ok := field.Type().Underlying().(*types.Basic); ok {
			basicsPred = append(basicsPred, pred)
			continue
		}
		others = append(others, pred)
	}

	var ret *jen.Statement
	for _, pred := range append(basicsPred, others...) {
		if ret == nil {
			ret = pred
			continue
		}
		ret = ret.Op("&&").Line().Add(pred)
	}

	if ret == nil {
		return jen.True()
	}
	return ret
}

// ptrToStructMethod generates the Comparator method for t, a pointer to the
// struct strct. The emitted method looks like:
//
//	func (cmp *Comparator) RefOfType(a, b *Type) bool {
//		if a == b {
//			return true
//		}
//		if a == nil || b == nil {
//			return false
//		}
//		// only when "*Type" was listed in EqualsOptions.AllowCustom
//		if cmp.RefOfType_ != nil {
//			return cmp.RefOfType_(a, b)
//		}
//		return <field-by-field comparison from compareAllStructFields>
//	}
//
// When t's type string is an allowed custom comparator, t is also recorded in
// e.comparators so that genFile emits the matching Comparator field.
func (e *equalsGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error {
	decl, name := e.declareFunc(t, "a", "b")

	body := []jen.Code{
		jen.If(jen.Id("a == b")).Block(jen.Return(jen.True())),
		jen.If(jen.Id("a == nil").Op("||").Id("b == nil")).Block(jen.Return(jen.False())),
	}

	key := types.TypeString(t, noQualifier)
	if _, allowed := e.comparators[key]; allowed {
		// Record the concrete type so genFile knows to emit its field.
		e.comparators[key] = t

		// A caller-supplied function, when set, overrides the field-by-field check.
		field := e.customComparatorField(t)
		override := jen.If(jen.Id("cmp").Dot(field).Op("!=").Nil()).Block(
			jen.Return(jen.Id("cmp").Dot(field).Call(jen.Id("a"), jen.Id("b"))),
		)
		body = append(body, override)
	}

	body = append(body, jen.Return(compareAllStructFields(strct, spi)))

	e.addFunc(name, decl.Block(body...))
	return nil
}

// ptrToBasicMethod generates the Comparator method for t, a pointer to a basic
// type. Identical pointers (including two nils) are equal, and a single nil is
// unequal. Otherwise the pointed-to values are compared with ==:
//
//	func (cmp *Comparator) RefOfBool(a, b *bool) bool {
//		if a == b {
//			return true
//		}
//		if a == nil || b == nil {
//			return false
//		}
//		return *a == *b
//	}
func (e *equalsGen) ptrToBasicMethod(t types.Type, _ *types.Basic, spi generatorSPI) error {
	samePointer := jen.If(jen.Id("a == b")).Block(jen.Return(jen.True()))
	oneNil := jen.If(jen.Id("a == nil").Op("||").Id("b == nil")).Block(jen.Return(jen.False()))
	compareValues := jen.Return(jen.Id("*a == *b"))

	decl, name := e.declareFunc(t, "a", "b")
	e.addFunc(name, decl.Block(samePointer, oneNil, compareValues))
	return nil
}

// declareFunc builds the signature of a Comparator method for t. It returns
// the signature together with the method name, which is printableTypeName(t).
// The signature renders as
//
//	func (cmp *Comparator) <Name>(<aArg>, <bArg> <T>) bool
//
// and is left open so the caller can attach a body with Block.
func (e *equalsGen) declareFunc(t types.Type, aArg, bArg string) (*jen.Statement, string) {
	paramType := types.TypeString(t, noQualifier)
	name := printableTypeName(t)

	receiver := jen.Id("cmp").Op("*").Id(Comparator)
	params := []jen.Code{jen.Id(aArg), jen.Id(bArg).Id(paramType)}
	signature := jen.Func().Params(receiver).Id(name).Call(params...).Bool()
	return signature, name
}

// sliceMethod generates the Comparator method for the slice type t. Two slices
// are equal when they have the same length and each pair of elements at the
// same index compares equal:
//
//	func (cmp *Comparator) SliceOfRefOfLeaf(a, b []*Leaf) bool {
//		if len(a) != len(b) {
//			return false
//		}
//		for i := 0; i < len(a); i++ {
//			if !cmp.RefOfLeaf(a[i], b[i]) {
//				return false
//			}
//		}
//		return true
//	}
//
// Because only lengths are checked, a nil slice and an empty slice compare
// equal.
func (e *equalsGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error {
	elemDiffers := compareValueType(slice.Elem(), jen.Id("a[i]"), jen.Id("b[i]"), false, spi)

	lengthMismatch := jen.If(jen.Id("len(a) != len(b)")).Block(jen.Return(jen.False()))
	elementLoop := jen.For(jen.Id("i := 0; i < len(a); i++")).Block(
		jen.If(elemDiffers).Block(jen.Return(jen.False())),
	)
	allEqual := jen.Return(jen.True())

	decl, name := e.declareFunc(t, "a", "b")
	e.addFunc(name, decl.Block(lengthMismatch, elementLoop, allEqual))
	return nil
}

// basicMethod emits nothing. Values whose type is basic are compared inline
// with == by compareValueType, so they need no dedicated Comparator method.
func (*equalsGen) basicMethod(types.Type, *types.Basic, generatorSPI) error {
	return nil
}
